{
 "cells": [
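  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiments from the paper *Deep Interest Evolution Network for Click-Through Rate Prediction*\n",
    "\n",
    "This notebook trains the `DNN` model from `prediction_flow` with PyTorch Lightning and reports the validation AUC.\n",
    "\n",
    "## How to run\n",
    "\n",
    "1. Run `prepare_neg.ipynb` first.\n",
    "2. Run this notebook from top to bottom; it reads `./local_train.csv` and `./local_test.csv` from the working directory."
   ]
  },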
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- run control ---\n",
    "SEED = 10         # shared seed for random / numpy / torch\n",
    "TEST_RUN = False  # True: work on a small random sample of each split\n",
    "EPOCH = 2         # number of epochs handed to the Trainer\n",
    "BATCH_SIZE = 128\n",
    "\n",
    "# --- model and feature sizes ---\n",
    "SEQ_MAX_LEN = 100  # maximum sequence length\n",
    "EMBEDDING_DIM = 18\n",
    "DNN_HIDDEN_SIZE = [200, 80]\n",
    "DNN_DROPOUT = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import itertools\n",
    "from collections import Counter, OrderedDict\n",
    "\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "from prediction_flow.features import Number, Category, Sequence, Features\n",
    "from prediction_flow.transformers.column import (\n",
    "    StandardScaler, CategoryEncoder, SequenceEncoder)\n",
    "\n",
    "from prediction_flow.pytorch.data import Dataset\n",
    "from prediction_flow.pytorch import WideDeep, DeepFM, DNN, DIN, DIEN, AttentionGroup\n",
    "\n",
    "from prediction_flow.pytorch.functions import fit, predict, create_dataloader_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f052fcd4150>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed each RNG the notebook relies on: stdlib random, NumPy, then PyTorch.\n",
    "for seed_fn in (random.seed, np.random.seed):\n",
    "    seed_fn(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both splits are tab-separated files read from the working directory.\n",
    "train_df = pd.read_csv(\"./local_train.csv\", sep=\"\\t\")\n",
    "valid_df = pd.read_csv(\"./local_test.csv\", sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "if TEST_RUN:\n",
    "    # Smoke-test mode: cap each split at 1000 rows. min() keeps the cell from\n",
    "    # raising when a split is smaller than that, and random_state makes the\n",
    "    # subset independent of how much global RNG state was consumed earlier.\n",
    "    train_df = train_df.sample(min(1000, len(train_df)), random_state=SEED)\n",
    "    valid_df = valid_df.sample(min(1000, len(valid_df)), random_state=SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>uid</th>\n",
       "      <th>mid</th>\n",
       "      <th>cat</th>\n",
       "      <th>hist_mids</th>\n",
       "      <th>hist_cats</th>\n",
       "      <th>neg_hist_mids</th>\n",
       "      <th>neg_hist_cats</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>AZPJ9LUT0FEPY</td>\n",
       "      <td>B00AMNNTIA</td>\n",
       "      <td>Literature &amp; Fiction</td>\n",
       "      <td>0307744434\u00020062248391\u00020470530707\u00020978924622\u000215...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "      <td>0786890487\u00020618539069\u0002B001IDZJO0\u00021603421548\u000203...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>AZPJ9LUT0FEPY</td>\n",
       "      <td>0800731603</td>\n",
       "      <td>Books</td>\n",
       "      <td>0307744434\u00020062248391\u00020470530707\u00020978924622\u000215...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "      <td>B00BEFIHOG\u00021402245270\u00020670031747\u00020615785182\u000214...</td>\n",
       "      <td>Literary\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>A2NRV79GKAU726</td>\n",
       "      <td>B003NNV10O</td>\n",
       "      <td>Russian</td>\n",
       "      <td>0814472869\u00020071462074\u00021583942300\u00020812538366\u0002B0...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Baking\u0002Books\u0002Books</td>\n",
       "      <td>B00LQABRTG\u0002087830178X\u00020991543009\u0002071533154X\u000203...</td>\n",
       "      <td>Neuropsychology\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>A2NRV79GKAU726</td>\n",
       "      <td>B000UWJ91O</td>\n",
       "      <td>Books</td>\n",
       "      <td>0814472869\u00020071462074\u00021583942300\u00020812538366\u0002B0...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Baking\u0002Books\u0002Books</td>\n",
       "      <td>1595328149\u00021591797810\u00020451233018\u00020373771355\u000214...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002Contempora...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>A2GEQVDX2LL4V3</td>\n",
       "      <td>0321334094</td>\n",
       "      <td>Books</td>\n",
       "      <td>0743596870\u00020374280991\u00021439140634\u00020976475731</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books</td>\n",
       "      <td>0316159735\u0002156718359X\u00020786812400\u00020062506110</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label             uid         mid                   cat  \\\n",
       "0      0   AZPJ9LUT0FEPY  B00AMNNTIA  Literature & Fiction   \n",
       "1      1   AZPJ9LUT0FEPY  0800731603                 Books   \n",
       "2      0  A2NRV79GKAU726  B003NNV10O               Russian   \n",
       "3      1  A2NRV79GKAU726  B000UWJ91O                 Books   \n",
       "4      0  A2GEQVDX2LL4V3  0321334094                 Books   \n",
       "\n",
       "                                           hist_mids  \\\n",
       "0  0307744434\u00020062248391\u00020470530707\u00020978924622\u000215...   \n",
       "1  0307744434\u00020062248391\u00020470530707\u00020978924622\u000215...   \n",
       "2  0814472869\u00020071462074\u00021583942300\u00020812538366\u0002B0...   \n",
       "3  0814472869\u00020071462074\u00021583942300\u00020812538366\u0002B0...   \n",
       "4        0743596870\u00020374280991\u00021439140634\u00020976475731   \n",
       "\n",
       "                                    hist_cats  \\\n",
       "0               Books\u0002Books\u0002Books\u0002Books\u0002Books   \n",
       "1               Books\u0002Books\u0002Books\u0002Books\u0002Books   \n",
       "2  Books\u0002Books\u0002Books\u0002Books\u0002Baking\u0002Books\u0002Books   \n",
       "3  Books\u0002Books\u0002Books\u0002Books\u0002Baking\u0002Books\u0002Books   \n",
       "4                     Books\u0002Books\u0002Books\u0002Books   \n",
       "\n",
       "                                       neg_hist_mids  \\\n",
       "0  0786890487\u00020618539069\u0002B001IDZJO0\u00021603421548\u000203...   \n",
       "1  B00BEFIHOG\u00021402245270\u00020670031747\u00020615785182\u000214...   \n",
       "2  B00LQABRTG\u0002087830178X\u00020991543009\u0002071533154X\u000203...   \n",
       "3  1595328149\u00021591797810\u00020451233018\u00020373771355\u000214...   \n",
       "4        0316159735\u0002156718359X\u00020786812400\u00020062506110   \n",
       "\n",
       "                                       neg_hist_cats  \n",
       "0                      Books\u0002Books\u0002Books\u0002Books\u0002Books  \n",
       "1                   Literary\u0002Books\u0002Books\u0002Books\u0002Books  \n",
       "2  Neuropsychology\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002...  \n",
       "3  Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002Contempora...  \n",
       "4                            Books\u0002Books\u0002Books\u0002Books  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>uid</th>\n",
       "      <th>mid</th>\n",
       "      <th>cat</th>\n",
       "      <th>hist_mids</th>\n",
       "      <th>hist_cats</th>\n",
       "      <th>neg_hist_mids</th>\n",
       "      <th>neg_hist_cats</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>A3BI7R43VUZ1TY</td>\n",
       "      <td>B00JNHU0T2</td>\n",
       "      <td>Literature &amp; Fiction</td>\n",
       "      <td>0989464105\u0002B00B01691C\u00021477809732\u00021608442845</td>\n",
       "      <td>Books\u0002Literature &amp; Fiction\u0002Books\u0002Books</td>\n",
       "      <td>0899576168\u0002B0056ATROO\u00020446600474\u00020615209459</td>\n",
       "      <td>Books\u0002Sleep\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>A3BI7R43VUZ1TY</td>\n",
       "      <td>0989464121</td>\n",
       "      <td>Books</td>\n",
       "      <td>0989464105\u0002B00B01691C\u00021477809732\u00021608442845</td>\n",
       "      <td>Books\u0002Literature &amp; Fiction\u0002Books\u0002Books</td>\n",
       "      <td>0373527721\u00020981854524\u00020470404159\u0002B00BWKBSOY</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Literature &amp; Fiction</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>A2Z3AHJPXG3ZNP</td>\n",
       "      <td>B0072YSPJ0</td>\n",
       "      <td>Literature &amp; Fiction</td>\n",
       "      <td>1478310960\u00021492231452\u00021477603425\u0002B00FRKLA6Q</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Urban</td>\n",
       "      <td>B00EQAEA60\u0002B007D64VX6\u0002188547766X\u00021590172477</td>\n",
       "      <td>Literature &amp; Fiction\u0002Quran\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>A2Z3AHJPXG3ZNP</td>\n",
       "      <td>B00G4I4I5U</td>\n",
       "      <td>Urban</td>\n",
       "      <td>1478310960\u00021492231452\u00021477603425\u0002B00FRKLA6Q</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Urban</td>\n",
       "      <td>1583942475\u00021585678600\u00021570199221\u00020312373090</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>A2KDDPJUNWC5CA</td>\n",
       "      <td>0316228532</td>\n",
       "      <td>Books</td>\n",
       "      <td>0141326085\u0002031026622X\u00020316077046\u00020988649179\u000214...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "      <td>B0077FOPFC\u00021594744106\u0002B00DFGN1DE\u00020972259112\u0002B0...</td>\n",
       "      <td>Ghosts\u0002Books\u0002Erotica\u0002Books\u0002Soups &amp; Stews</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label             uid         mid                   cat  \\\n",
       "0      0  A3BI7R43VUZ1TY  B00JNHU0T2  Literature & Fiction   \n",
       "1      1  A3BI7R43VUZ1TY  0989464121                 Books   \n",
       "2      0  A2Z3AHJPXG3ZNP  B0072YSPJ0  Literature & Fiction   \n",
       "3      1  A2Z3AHJPXG3ZNP  B00G4I4I5U                 Urban   \n",
       "4      0  A2KDDPJUNWC5CA  0316228532                 Books   \n",
       "\n",
       "                                           hist_mids  \\\n",
       "0        0989464105\u0002B00B01691C\u00021477809732\u00021608442845   \n",
       "1        0989464105\u0002B00B01691C\u00021477809732\u00021608442845   \n",
       "2        1478310960\u00021492231452\u00021477603425\u0002B00FRKLA6Q   \n",
       "3        1478310960\u00021492231452\u00021477603425\u0002B00FRKLA6Q   \n",
       "4  0141326085\u0002031026622X\u00020316077046\u00020988649179\u000214...   \n",
       "\n",
       "                                hist_cats  \\\n",
       "0  Books\u0002Literature & Fiction\u0002Books\u0002Books   \n",
       "1  Books\u0002Literature & Fiction\u0002Books\u0002Books   \n",
       "2                 Books\u0002Books\u0002Books\u0002Urban   \n",
       "3                 Books\u0002Books\u0002Books\u0002Urban   \n",
       "4           Books\u0002Books\u0002Books\u0002Books\u0002Books   \n",
       "\n",
       "                                       neg_hist_mids  \\\n",
       "0        0899576168\u0002B0056ATROO\u00020446600474\u00020615209459   \n",
       "1        0373527721\u00020981854524\u00020470404159\u0002B00BWKBSOY   \n",
       "2        B00EQAEA60\u0002B007D64VX6\u0002188547766X\u00021590172477   \n",
       "3        1583942475\u00021585678600\u00021570199221\u00020312373090   \n",
       "4  B0077FOPFC\u00021594744106\u0002B00DFGN1DE\u00020972259112\u0002B0...   \n",
       "\n",
       "                              neg_hist_cats  \n",
       "0                   Books\u0002Sleep\u0002Books\u0002Books  \n",
       "1    Books\u0002Books\u0002Books\u0002Literature & Fiction  \n",
       "2    Literature & Fiction\u0002Quran\u0002Books\u0002Books  \n",
       "3                   Books\u0002Books\u0002Books\u0002Books  \n",
       "4  Ghosts\u0002Books\u0002Erotica\u0002Books\u0002Soups & Stews  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# define features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_enc = SequenceEncoder(sep=\"\\x02\", min_cnt=1, max_len=SEQ_MAX_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<prediction_flow.transformers.column.sequence_encoder.SequenceEncoder at 0x7f04ba916080>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit on the target `cat` column as well as the history, mirroring how mid_enc\n",
    "# is fitted, because this vocabulary is also used to encode the `cat` feature.\n",
    "# Target categories that never occur in a history are then part of the fitted data.\n",
    "cat_enc.fit(np.vstack([train_df['cat'].values, train_df.hist_cats.values]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_word2idx, cat_idx2word = cat_enc.word2idx, cat_enc.idx2word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1602\n"
     ]
    }
   ],
   "source": [
    "print(len(cat_word2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "mid_enc = SequenceEncoder(sep=\"\\x02\", min_cnt=1, max_len=SEQ_MAX_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<prediction_flow.transformers.column.sequence_encoder.SequenceEncoder at 0x7f04ba92a7b8>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mid_enc.fit(np.vstack([train_df.mid.values, train_df.hist_mids.values]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "mid_word2idx, mid_idx2word = mid_enc.word2idx, mid_enc.idx2word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "367984\n"
     ]
    }
   ],
   "source": [
    "print(len(mid_word2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "number_features = []\n",
    "\n",
    "# The vocabularies fitted above are shared by each target column and its history.\n",
    "mid_vocab = dict(word2idx=mid_word2idx, idx2word=mid_idx2word)\n",
    "cat_vocab = dict(word2idx=cat_word2idx, idx2word=cat_idx2word)\n",
    "seq_options = dict(sep=\"\\x02\", min_cnt=1, max_len=SEQ_MAX_LEN)\n",
    "\n",
    "category_features = [\n",
    "    Category('mid', CategoryEncoder(min_cnt=1, **mid_vocab), embedding_name='mid'),\n",
    "    Category('cat', CategoryEncoder(min_cnt=1, **cat_vocab), embedding_name='cat'),\n",
    "]\n",
    "\n",
    "sequence_features = [\n",
    "    Sequence('hist_mids', SequenceEncoder(**seq_options, **mid_vocab), embedding_name='mid'),\n",
    "    Sequence('hist_cats', SequenceEncoder(**seq_options, **cat_vocab), embedding_name='cat'),\n",
    "]\n",
    "\n",
    "features, train_loader, valid_loader = create_dataloader_fn(\n",
    "    number_features,\n",
    "    category_features,\n",
    "    sequence_features,\n",
    "    BATCH_SIZE,\n",
    "    train_df,\n",
    "    'label',\n",
    "    valid_df,\n",
    "    4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(model, df, dataloader):\n",
    "    \"\"\"Return the ROC AUC of `model` on `dataloader`, scored against `df['label']`.\n",
    "\n",
    "    The output of `predict` is flattened with `ravel()` before scoring.\n",
    "    Assumes `dataloader` yields the rows of `df` in order -- TODO confirm.\n",
    "    \"\"\"\n",
    "    scores = predict(model, dataloader).ravel()\n",
    "    labels = df['label']\n",
    "    return roc_auc_score(labels, scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "\n",
    "class CoolModel(pl.LightningModule):\n",
    "    \"\"\"LightningModule wrapping the prediction_flow `DNN` for the binary label.\n",
    "\n",
    "    Args:\n",
    "        learning_rate: step size passed to Adam in `configure_optimizers`.\n",
    "            Defaults to 0.003, the value that was previously hard-coded.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, learning_rate=0.003):\n",
    "        super().__init__()\n",
    "        self.learning_rate = learning_rate\n",
    "        self.model = DNN(\n",
    "            features,\n",
    "            2,\n",
    "            EMBEDDING_DIM,\n",
    "            DNN_HIDDEN_SIZE,\n",
    "            final_activation='sigmoid',\n",
    "            dropout=DNN_DROPOUT)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the wrapped DNN on one batch.\"\"\"\n",
    "        return self.model(x)\n",
    "\n",
    "    def _bce_loss(self, batch):\n",
    "        \"\"\"Binary cross-entropy between the model output and `batch['label']`.\"\"\"\n",
    "        y = batch['label']\n",
    "        y_hat = self.forward(batch)\n",
    "        return F.binary_cross_entropy(y_hat, y)\n",
    "\n",
    "    def training_step(self, batch, batch_nb):\n",
    "        # REQUIRED\n",
    "        loss = self._bce_loss(batch)\n",
    "        return {\n",
    "            'loss': loss,\n",
    "            'progress_bar':\n",
    "            {'training_loss': loss}}\n",
    "\n",
    "    def validation_step(self, batch, batch_nb):\n",
    "        # OPTIONAL\n",
    "        return {'val_loss': self._bce_loss(batch)}\n",
    "\n",
    "    def validation_end(self, outputs):\n",
    "        # OPTIONAL: average the per-batch validation losses for the progress bar\n",
    "        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()\n",
    "        return {'progress_bar': {'val_loss': avg_loss}}\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        # REQUIRED\n",
    "        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
    "\n",
    "    @pl.data_loader\n",
    "    def train_dataloader(self):\n",
    "        return train_loader\n",
    "\n",
    "    @pl.data_loader\n",
    "    def val_dataloader(self):\n",
    "        # OPTIONAL\n",
    "        # can also return a list of val dataloaders\n",
    "        return valid_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpu available: True, used: True\n",
      "VISIBLE GPUS: 0\n"
     ]
    }
   ],
   "source": [
    "from pytorch_lightning import Trainer\n",
    "\n",
    "model = CoolModel()\n",
    "\n",
    "# most basic trainer, uses good defaults\n",
    "# Train on one GPU when CUDA is available and fall back to CPU otherwise, so the\n",
    "# notebook also runs on machines without a GPU.\n",
    "trainer = Trainer(max_nb_epochs=EPOCH, gpus=1 if torch.cuda.is_available() else None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/5 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                 Name         Type Params\n",
      "0                               model          DNN    6 M\n",
      "1                 model.embedding:mid    Embedding    6 M\n",
      "2                 model.embedding:cat    Embedding   28 K\n",
      "3             model.pooling:hist_mids   MaxPooling    0  \n",
      "4             model.pooling:hist_cats   MaxPooling    0  \n",
      "5                           model.mlp          MLP   31 K\n",
      "6               model.mlp._sequential   Sequential   31 K\n",
      "7        model.mlp._sequential.dense0       Linear   14 K\n",
      "8    model.mlp._sequential.batchnorm0  BatchNorm1d  400  \n",
      "9   model.mlp._sequential.activation0         ReLU    0  \n",
      "10       model.mlp._sequential.dense1       Linear   16 K\n",
      "11   model.mlp._sequential.batchnorm1  BatchNorm1d  160  \n",
      "12  model.mlp._sequential.activation1         ReLU    0  \n",
      "13                  model.final_layer       Linear   81  \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 9433/9433 [01:21<00:00, 317.80it/s, batch_nb=8485, epoch=1, gpu=0, loss=0.596, training_loss=0.571, v_nb=1, val_loss=0.625]"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "100%|██████████| 9433/9433 [01:40<00:00, 317.80it/s, batch_nb=8485, epoch=1, gpu=0, loss=0.596, training_loss=0.571, v_nb=1, val_loss=0.625]"
     ]
    }
   ],
   "source": [
    "trainer.fit(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "score = evaluation(model, valid_df, valid_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc: 0.7025202354666744\n"
     ]
    }
   ],
   "source": [
    "print(f'auc: {score}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
